import json
from sklearn.model_selection import train_test_split


# Split the labelled records of one JSON file into a training set and a
# class-balanced test set, then write each set to its own JSON file.

# Input file (despite the name, this is a file path, not a directory).
json_dir = r'/home/lxy/gaobao_data/db_v2.1/train&valid.json'
train_file = r'/home/lxy/gaobao_data/db_v2.1/train.json'
test_file = r'/home/lxy/gaobao_data/db_v2.1/test.json'

# Number of records of EACH label held out for the test set (an integer
# test_size is an absolute sample count, not a fraction).
test_size_per_class = 500
# Fixed seed so that re-running the script reproduces the same partition
# instead of silently reshuffling records between train and test.
random_seed = 42

# Load everything up front; the input handle is released before any further
# work is done.
with open(json_dir, 'r', encoding='utf-8') as f:
    data = json.load(f)

# Partition by label. Records whose label equals neither 0 nor 1 end up in
# neither list and are therefore absent from both output files.
neg = [d for d in data if d['label'] == 0]
pos = [d for d in data if d['label'] == 1]
print(len(neg), len(pos))

# Split each class separately so the test set holds the same number of
# positives and negatives.
# NOTE(review): train_test_split is expected to raise ValueError when a class
# has too few records to leave a non-empty training part - confirm the input
# always has more than test_size_per_class records per label.
pos_train, pos_test = train_test_split(
    pos, test_size=test_size_per_class, random_state=random_seed
)
neg_train, neg_test = train_test_split(
    neg, test_size=test_size_per_class, random_state=random_seed
)

# Positives first, then negatives, in both output sets.
train = pos_train + neg_train
test = pos_test + neg_test
print(len(train), len(test))

# Write each file in its own context so a failure while dumping the training
# set does not leave an already-truncated, empty test file behind.
with open(train_file, 'w', encoding='utf-8') as tr:
    json.dump(train, tr, ensure_ascii=False, indent=4)
with open(test_file, 'w', encoding='utf-8') as te:
    json.dump(test, te, ensure_ascii=False, indent=4)
    

